package com.linkedin.dagli.examples.fasttextandavro;

import com.linkedin.dagli.dag.DAG;
import com.linkedin.dagli.dag.DAG1x1;
import com.linkedin.dagli.evaluation.MultinomialEvaluation;
import com.linkedin.dagli.evaluation.MultinomialEvaluationResult;
import com.linkedin.dagli.fasttext.FastTextClassification;
import com.linkedin.dagli.math.distribution.DiscreteDistribution;
import com.linkedin.dagli.object.Convert;
import com.linkedin.dagli.objectio.ObjectReader;
import com.linkedin.dagli.objectio.SampleSegment;
import com.linkedin.dagli.objectio.avro.AvroReader;
import com.linkedin.dagli.text.token.Tokens;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectOutputStream;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.Locale;


/**
 * This is an example of a simple Dagli ML pipeline that learns a model for predicting the character (e.g. Caesar,
 * Hamlet, etc.) from a line the character utters in a Shakespearean play.
 *
 * - Data is read from an Avro source.
 * - The statistical model used is FastText.
 *
 * The example is very close to what a simple initial model using FastText might look like with an Avro data source.
 * This example can optionally write the model to a file; you can use {@link ShakespeareCLI} to load that model and do
 * inference with it.
 *
 * If you are trying to run this example in an IDE, please note that you must have annotation processing enabled (see
 * {@code dagli/documentation/structs.md}).
 */
public class FastTextExample {
  /**
   * Trains a model, evaluates it, and (optionally) saves it.  See ShakespeareCLI (in this directory) for code that
   * loads the saved model and does inference with it.
   *
   * A path to write the model to may be given as an argument.  Data is read from a resource file.  In a real
   * application you'd probably read your data from a normal file (or the network) rather than a resource file, of
   * course.
   *
   * Here's what the code does:
   * (1) Uses {@link AvroReader} to read our Avro data.  Reading doesn't actually happen until training starts, as data
   *     will be streamed from disk rather than read into RAM.  This can be important when training on large datasets.
   * (2) Calls createDAG() to get the unprepared, trainable DAG.
   * (3) Trains the DAG.
   * (4) Evaluates the trained DAG.
   * (5) Serializes the trained DAG to disk.
   *
   * @param arguments empty array, or a single-element array with the path to which to write the model.  If this file
   *                  already exists the program will immediately terminate rather than overwrite it.
   */
  public static void main(String[] arguments) throws IOException, URISyntaxException, ClassNotFoundException {
    // let's get the model path, verify that it doesn't exist yet, and create the file:
    final Path modelPath = modelPathFromArguments(arguments);

    // Avro doesn't support reading directly from a resource file, so we'll need to copy it into a temp file first.
    // In the real world, of course, the data would already be in a regular file and this wouldn't be necessary.
    Path avroPath = Files.createTempFile("shakespeare", ".avro");
    // Files.copy(InputStream, ...) does not close the supplied stream, so we use try-with-resources to close it
    // ourselves; we also check for null, since getResourceAsStream returns null if the resource is missing.
    try (InputStream shakespeareData = FastTextExample.class.getResourceAsStream("/shakespeare.avro")) {
      if (shakespeareData == null) {
        throw new IOException("Unable to find the /shakespeare.avro resource on the classpath");
      }
      Files.copy(shakespeareData, avroPath, StandardCopyOption.REPLACE_EXISTING);
    }
    avroPath.toFile().deleteOnExit(); // clean up the temp file when the JVM exits

    // Now we can create our AvroReader to read our examples as CharacterDialogStructs.  CharacterDialogStructs are
    // @Structs that extend the auto-generated Avro class CharacterDialog, and are thus valid Avro objects with all the
    // various autogenerated methods and inner class goodies that @Structs provide.  See package-info.java for how we
    // set this up.
    try (AvroReader<CharacterDialogStruct> examples = new AvroReader<>(CharacterDialogStruct.class, avroPath)) {
      // We've got our data, but we want to split it into training and evaluation data sets.  A convenient
      // way to do that is using "SampleSegments".  These partition the space of examples in random sets by assigning
      // a random number between 0 and 1 to every example, and then figuring out if that number falls into the segment's
      // range: if it is, it's part of the sample, otherwise, it's not.  Sampling the same segment, with the same seed
      // (here we don't specify a seed and just use the default, constant seed value), from the same data always yields
      // the same subset of data.
      final SampleSegment trainingSegment = new SampleSegment(0.0, 0.8); // use 80% of the data for training
      final SampleSegment evaluationSegment = trainingSegment.complement(); // and the rest for evaluation

      // Now we're ready to train.  Because inference in FastText can be relatively expensive when the number of labels
      // is large, we can train the model by calling the "prepare(...)" method on the DAG (preparing the DAG will train
      // the model using the provided examples).  If we instead called prepareAndApply(...) we'd get the trained model
      // as well as the inferences on the training examples.
      DAG1x1.Prepared<CharacterDialogStruct, DiscreteDistribution<String>> predictor =
          createDAG().prepare(examples.sample(trainingSegment));

      // Next we can use our model to predict labels for our held-out evaluation data, so we can see how well the model
      // will do on examples that it hasn't seen during training.  The result of a prediction is a DiscreteDistribution
      // (mapping possible character names to their probabilities) so we use lazyMap(...) to get the most likely
      // character name.
      ObjectReader<String> predictedCharacterNames = predictor.applyAll(examples.sample(evaluationSegment))
          .lazyMap(d -> d.mostLikelyLabel().orElse(null));

      // Now we can pull out the actual, true character names for each of our evaluation examples.  Note that these are
      // CharSequences, so we need to convert them to Strings so we can compare them to our predictions.
      ObjectReader<String> actualCharacterNames = examples.sample(evaluationSegment)
          .lazyMap(CharacterDialogStruct::getCharacter)
          .lazyMap(CharSequence::toString);

      // We can use MultinomialEvaluation to now compare the actual and predicted character names to determine our
      // performance; MultinomialEvaluation is also a transformer that can be used to do evaluation as part of a DAG,
      // but using this static method is more convenient for us.
      MultinomialEvaluationResult evaluation =
          MultinomialEvaluation.evaluate(actualCharacterNames, predictedCharacterNames);

      // Print out the evaluation:
      System.out.println("\nModel Evaluation");
      System.out.println("----------------");
      System.out.println(evaluation.getSummary() + "\n");

      // So now we've trained the model and printed out evaluation.  What's left?  We need to save it so we can do
      // inference later.  DAGs are Java-serializable, so this is quite easy:
      if (modelPath != null) {
        try (ObjectOutputStream oos = new ObjectOutputStream(Files.newOutputStream(modelPath))) {
          oos.writeObject(predictor);
        }
        System.out.println("Saved model to " + modelPath.toString() + "\n");
      }
    }
    // And we're done!  There are, of course, many ways to deploy your serialized model, but the easiest is probably
    // to just put it into a resource file and then load it up with Java deserialization.  See the ShakespeareCLI
    // class for just such an example.
  }

  /**
   * Gets the model path (if any) from the arguments, checks that it doesn't already exist, and creates the file that
   * will be written to.  If the path already exists, prints an explanation and terminates the JVM rather than risk
   * overwriting an existing model.
   *
   * @param arguments the arguments passed to the program; may be empty, in which case no model will be saved
   * @return a Path for the newly-created model file, or null if no path was provided
   * @throws IOException if the model file cannot be created
   */
  private static Path modelPathFromArguments(String[] arguments) throws IOException {
    if (arguments.length == 0) {
      return null;
    } else {
      Path res = Paths.get(arguments[0]);
      if (Files.exists(res)) {
        System.out.println("The output path provided to save the model already exists: " + arguments[0]);
        System.out.println("For safety, this program will not overwrite an existing file.  Aborting.");
        System.exit(1);
      }
      Files.createFile(res); // create the file to make sure it's createable; we'll write to it later
      return res;
    }
  }

  /**
   * This method returns the DAG we'll use to define our model.  It's a model that takes a CharacterDialogStruct as its
   * input and returns the probabilities of each Shakespeare character (the predicted discrete distribution).
   *
   * Please note that, at inference-time, the name of the Shakespeare character in CharacterDialogStruct (the true
   * label) will be null.  This isn't a problem since FastText doesn't use the label at inference-time.
   *
   * At training-time, the output of the DAG is useful for auto-evaluation (e.g. determining how well the model does
   * on its own training data); this can be useful for gauging if the model is expressive enough to capture the data and
   * for determining how much it's overfitting (high auto-evaluation metrics and much lower metrics on unseen data
   * suggest that the model is overfitting to the training data).
   *
   * @return a text classification DAG that accepts the true label (String) and the line of dialog (String) and produces
   *         a discrete distribution over the labels.
   */
  private static DAG1x1<CharacterDialogStruct, DiscreteDistribution<String>> createDAG() {
    // Define a "placeholder" in the DAG.  When the DAG is executed, placeholder values are provided as inputs.  If we
    // view the DAG as consuming a sequence of "rows", where each row is an example, each placeholder is a "column".
    // We're going to use CharacterDialogStruct.Placeholder which has some useful convenience methods.  We could also
    // just use Placeholder<CharacterDialogStruct> (which CharacterDialogStruct.Placeholder extends) but our code would
    // end up being slightly more verbose in the next few lines.
    CharacterDialogStruct.Placeholder characterDialog = new CharacterDialogStruct.Placeholder();

    // FastText requires the text to be tokenized (broken into a list of words), so let's do so now.
    // "characterDialog.asDialog()" is shorthand for "new CharacterDialogStruct.Dialog().withInput(characterDialog)":
    // it just pulls out the dialog text from the CharacterDialogStruct objects from the characterDialog source.
    Tokens dialogTokens = new Tokens().withLocale(Locale.ENGLISH).withTextInput(characterDialog.asDialog());

    // We're ready to configure FastText itself; we'll set up a few hyperparameters and specify the labels and tokens
    // inputs (these are not the *optimal* hyperparameters by any stretch--you may want to experiment with others!)
    //
    // However, we have a problem: serializing the model requires serializing the labels.  But the character field in
    // the CharacterDialog Avro class is of type CharSequence, which does not implement Serializable.
    // To solve this, we can simply transform the character label by:
    // (1) using asCharacter() to pull the "character" field out of characterDialog
    // (2) using Convert.toString(...) to convert "character" to a String, which *is* serializable.
    // (Convert.toString(...) simply creates a transformer that calls toString() on its input)
    FastTextClassification<String> fastTextClassification =
        new FastTextClassification<String>()
            .withLabelInput(Convert.toString(characterDialog.asCharacter()))
            .withTokensInput(dialogTokens)
            .withEmbeddingLength(64)
            .withMinTokenCount(1)
            .withMaxWordNgramLength(2) // unigrams and bigrams only
            .withEpochCount(200)
            .withBucketCount(200000); // fewer buckets than default to save RAM; don't use this value for real problems

    // Note: we could use BestModel to do cross-validation to search over possible hyperparameter values for FastText
    // and find the best model variant, and given that our data set is relatively small this is very practical.
    // However, this is beyond the scope of this example.

    // For now, we're done!  Build the DAG by specifying the output and the required placeholder input:
    return DAG.withPlaceholder(characterDialog).withOutput(fastTextClassification);
  }

  private FastTextExample() { } // nobody will be creating instances of this class
}
